Add mixed precision training to TorchEngine #1322

Merged: 1 commit merged into master from nick-add-amp on May 15, 2023
Conversation

@JackTemaki (Collaborator) commented May 4, 2023

Uses `torch_amp` as a config dict with a `dtype` option.

Adds a GradScaler to the engine, and applies autocast and the scaler during training if AMP is enabled. Uses `grad_scaler` as a config option to configure it explicitly.

Fixes #1334.

uses torch_amp_options as config dict with "dtype" option.
Adds GradScaler to engine, and applies autocast and the scaler during training if
amp is enabled.
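
For context, a minimal sketch of how autocast and a GradScaler are typically combined in a PyTorch training step. The model, optimizer, and the way the config dtype is read are placeholders, not the actual TorchEngine code:

```python
import torch

# Minimal AMP training-step sketch (placeholder model/optimizer, not TorchEngine).
amp_dtype = torch.float16  # e.g. derived from the config: torch_amp = {"dtype": "float16"}
model = torch.nn.Linear(128, 10).to("cuda")
optimizer = torch.optim.Adam(model.parameters())
# The scaler is only needed for float16; bfloat16 usually does not require it.
scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))


def train_step(inputs, targets):
    optimizer.zero_grad()
    # Forward pass and loss run under autocast; eligible ops use amp_dtype.
    with torch.autocast(device_type="cuda", dtype=amp_dtype):
        logits = model(inputs)
        loss = torch.nn.functional.cross_entropy(logits, targets)
    # Scale the loss so that small float16 gradients do not underflow;
    # scaler.step() unscales the gradients before the optimizer update.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```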
JackTemaki requested review from a team and albertz as code owners on May 4, 2023 14:43
albertz requested a review from Icemole on May 4, 2023 15:04
@albertz (Member) left a comment

Looks fine.

@albertz (Member) commented May 4, 2023

I wonder in what cases you would want the model params to also use the same dtype. I have read that when using bfloat16, you usually want the params to be stored in bfloat16 as well.

With float16 AMP training, is it normal to keep the params in float32?

What about TensorFloat32?

So I wonder whether this is something that should also be handled automatically via torch_amp_options, or whether it should be a separate option. And I wonder how people usually do it.

@albertz (Member) commented May 15, 2023

I'll just merge this as an initial version now.

Can you comment on my questions?

albertz merged commit 003b2e8 into master on May 15, 2023
albertz deleted the nick-add-amp branch on May 15, 2023 07:35
@albertz (Member) commented May 15, 2023

Another question: was it intentional to allow the user to not specify dtype? I also wonder what exactly the behavior of autocast is if you use dtype=None.

@JackTemaki (Collaborator, Author) commented:

> With float16 AMP training, is it normal to keep the params in float32?

From the PyTorch docs (https://pytorch.org/docs/stable/amp.html#torch.autocast):

> When entering an autocast-enabled region, Tensors may be any type. You should not call half() or bfloat16() on your model(s) or inputs when using autocasting.

So yes, you should not do anything to the model explicitly because autocast is handling that.

> Another question: was it intentional to allow the user to not specify dtype? I also wonder what exactly the behavior of autocast is if you use dtype=None.

No, this is a mistake; it should be given. I do not know what the behavior is in that case, but it is likely not what is intended.

@albertz (Member) commented May 15, 2023

> So yes, you should not do anything to the model explicitly because autocast is handling that.

Autocast automatically casts the inputs to certain PyTorch ops. The parameters are not changed; they are just cast automatically for those ops. But this is not really my question. My question is: wouldn't it make more sense to directly have the parameters in float16?
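
A small demo of that behavior (a generic PyTorch example on a CUDA device, not TorchEngine code): the module parameters keep float32, and only the op executed inside the autocast region produces float16 outputs.

```python
import torch

layer = torch.nn.Linear(8, 8).to("cuda")
x = torch.randn(4, 8, device="cuda")

print(layer.weight.dtype)                 # torch.float32
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = layer(x)
    print(y.dtype)                        # torch.float16 (the matmul ran in float16)
print(layer.weight.dtype)                 # still torch.float32 after the region
```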

albertz added a commit that referenced this pull request May 15, 2023
Follow-up to #1322

Rename torch_amp_options to torch_amp.

Allow simply `torch_amp = 'float16'` in config.

Allow to specify grad_scaler separately.
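
To make this concrete, here is a hedged sketch of how these config variants might look in a RETURNN config file (which is a Python file). The exact set of accepted keys and grad_scaler values is an assumption based on the commit message, not verified against the code:

```python
# Sketch of the config variants described by this follow-up commit.

# Short form: only the autocast dtype.
torch_amp = "float16"

# Dict form with an explicit "dtype" entry (e.g. for bfloat16):
# torch_amp = {"dtype": "bfloat16"}

# grad_scaler can be configured separately; the options shown here are the
# arguments of torch.cuda.amp.GradScaler and are only illustrative.
# grad_scaler = {"init_scale": 2.0 ** 16, "growth_interval": 2000}
```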
@albertz (Member) commented May 15, 2023

I renamed torch_amp_options just to torch_amp.

I was even thinking about renaming it to just amp, or maybe autocast. We also don't have a torch_ prefix for other options, and you might want to use this for other backends as well. For example, jmp implements automatic mixed precision training for JAX.

@JackTemaki (Collaborator, Author) commented:

> So yes, you should not do anything to the model explicitly because autocast is handling that.

> Autocast automatically casts the inputs to certain PyTorch ops. The parameters are not changed; they are just cast automatically for those ops. But this is not really my question. My question is: wouldn't it make more sense to directly have the parameters in float16?

I see no indication why, unless you really want your whole network to run in float16.

@albertz (Member) commented May 15, 2023

> My question is: wouldn't it make more sense to directly have the parameters in float16?

> I see no indication why

Because that further reduces the memory requirement? Why would you not want that? What are the downsides?

I'm not saying that everything should be float16. Maybe certain ops must stay in float32. I thought this was the main aspect of autocast/AMP: to cast to float16 wherever it makes sense.

I just don't understand why the weights are stored in float32 and then always auto-cast. That also adds some overhead in computation (the casting) and requires more memory. Unless there is maybe some reason. But that is my question: what is the reason for this?

@albertz (Member) commented May 15, 2023

Ah, I was just checking the original paper introducing automatic mixed precision training ("Mixed Precision Training", Micikevicius et al., 2017), and it explains this (Sec. 3.1):

> In mixed precision training, weights, activations and gradients are stored as FP16. In order to match the accuracy of the FP32 networks, an FP32 master copy of weights is maintained and updated with the weight gradient during the optimizer step. In each iteration an FP16 copy of the master weights is used in the forward and backward pass. ...
>
> While the need for FP32 master weights is not universal, there are two possible reasons why a number of networks require it. One explanation is that updates (weight gradients multiplied by the learning rate) become too small to be represented in FP16 - any value whose magnitude is smaller than $2^{-24}$ becomes zero in FP16. ...
>
> Another explanation is that the ratio of the weight value to the weight update is very large. In this case, even though the weight update is representable in FP16, it could still become zero when addition operation right-shifts it to align the binary point with the weight. ...
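
For illustration, both effects can be reproduced directly in PyTorch (a generic numeric example, not from the PR):

```python
import torch

# 1) Updates with magnitude below 2**-24 underflow to zero in float16.
tiny_update = torch.tensor(2.0 ** -27, dtype=torch.float16)
print(tiny_update)                     # tensor(0., dtype=torch.float16)

# 2) Even a representable update can vanish when added to a much larger weight,
#    because the addition right-shifts the small operand to align it.
weight = torch.tensor(1.0, dtype=torch.float16)
update = torch.tensor(2.0 ** -12, dtype=torch.float16)  # representable in FP16
print(weight + update)                 # tensor(1., dtype=torch.float16)

# Keeping a float32 master copy avoids both problems during the optimizer step.
print(weight.float() + update.float())  # tensor(1.0002)
```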

JackTemaki added a commit to JackTemaki/MiniReturnn that referenced this pull request May 16, 2023
uses torch_amp_options as config dict with "dtype" option.
Adds GradScaler to engine, and applies autocast and the scaler during training if
amp is enabled.
@JackTemaki (Collaborator, Author) commented:

Now that the code is longer, we might want to move this into the updater or add an extra module instead of having it plain in the engine.

Linked issue: PyTorch automatic mixed precision (AMP) support